{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c532deb2-1ada-49e0-ad8d-4c42c4f9bfec",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-12-02T12:45:52.056534Z",
     "iopub.status.busy": "2023-12-02T12:45:52.056208Z",
     "iopub.status.idle": "2023-12-02T12:45:52.271540Z",
     "shell.execute_reply": "2023-12-02T12:45:52.270889Z",
     "shell.execute_reply.started": "2023-12-02T12:45:52.056517Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "!export HF_ENDPOINT=https://hf-mirror.com"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7defb6d4-d42d-4e13-811b-cfebdd20d2ee",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-12-03T10:13:44.194308Z",
     "iopub.status.busy": "2023-12-03T10:13:44.194009Z",
     "iopub.status.idle": "2023-12-03T10:15:16.713072Z",
     "shell.execute_reply": "2023-12-03T10:15:16.712531Z",
     "shell.execute_reply.started": "2023-12-03T10:13:44.194289Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/llava/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-12-03 18:13:47,809] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [01:15<00:00, 37.71s/it]\n"
     ]
    }
   ],
   "source": [
    "from llava.model.builder import load_pretrained_model\n",
    "from llava.mm_utils import get_model_name_from_path\n",
    "from llava.eval.run_llava import eval_model\n",
    "\n",
    "model_path = \"/mnt/workspace/llava-v1.5-7b\"\n",
    "\n",
    "tokenizer, model, image_processor, context_len = load_pretrained_model(\n",
    "    model_path=model_path,\n",
    "    model_base=None,\n",
    "    model_name=get_model_name_from_path(model_path)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4ec74850-2dd8-40bd-9b3a-b8520f9d167d",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-12-03T10:24:40.797032Z",
     "iopub.status.busy": "2023-12-03T10:24:40.796459Z",
     "iopub.status.idle": "2023-12-03T10:24:40.799448Z",
     "shell.execute_reply": "2023-12-03T10:24:40.799031Z",
     "shell.execute_reply.started": "2023-12-03T10:24:40.797012Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llava.conversation import conv_templates\n",
    "from llava.mm_utils import tokenizer_image_token\n",
    "from PIL import Image\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6211c169-e5e5-463c-a0d4-13c58f0f0273",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-03T10:23:44.792510Z",
     "iopub.status.busy": "2023-12-03T10:23:44.792207Z",
     "iopub.status.idle": "2023-12-03T10:23:44.795001Z",
     "shell.execute_reply": "2023-12-03T10:23:44.794563Z",
     "shell.execute_reply.started": "2023-12-03T10:23:44.792493Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "image_file = '/mnt/workspace/LLaVA-1.5/notebook/photo.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e9623bc2-ffae-4832-9fcf-de5ff880ac5a",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-12-03T10:30:09.473366Z",
     "iopub.status.busy": "2023-12-03T10:30:09.473052Z",
     "iopub.status.idle": "2023-12-03T10:30:09.477226Z",
     "shell.execute_reply": "2023-12-03T10:30:09.476659Z",
     "shell.execute_reply.started": "2023-12-03T10:30:09.473347Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Args(temperature=0.2, top_p=None, num_beams=1)\n"
     ]
    }
   ],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class Args:\n",
    "    temperature: float = 0.2\n",
    "    top_p: float = None\n",
    "    num_beams: int = 1\n",
    "\n",
    "args = Args()\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "28bcb14e-c6ce-4c21-9954-047f2deb99d2",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-12-03T10:30:12.403707Z",
     "iopub.status.busy": "2023-12-03T10:30:12.403411Z",
     "iopub.status.idle": "2023-12-03T10:30:21.410224Z",
     "shell.execute_reply": "2023-12-03T10:30:21.409732Z",
     "shell.execute_reply.started": "2023-12-03T10:30:12.403691Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The image depicts a small, clean, and well-organized bedroom with a bed and a desk. The bed is positioned against the left wall, and the desk is located on the right side of the room. A chair is placed in front of the desk, providing a comfortable seating area for work or study. \n",
      "\n",
      "In addition to the bed and desk, there is a TV mounted on the wall above the bed, and a laptop is placed on the desk. A few books can be seen on the desk, and a cup is located near the right edge of the desk. A teddy bear is also present in the room, adding a touch of warmth and personality to the space.\n"
     ]
    }
   ],
   "source": [
    "DEFAULT_IMAGE_TOKEN = \"<image>\"\n",
    "IMAGE_TOKEN_INDEX = -200\n",
    "qs = 'describe this image'\n",
    "qs = DEFAULT_IMAGE_TOKEN + '\\n' + qs\n",
    "conv = conv_templates['llava_v1'].copy()\n",
    "conv.append_message(conv.roles[0], qs)\n",
    "conv.append_message(conv.roles[1], None)\n",
    "prompt = conv.get_prompt()\n",
    "\n",
    "input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
    "\n",
    "image = Image.open(image_file)\n",
    "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n",
    "\n",
    "stop_str = '<s>'\n",
    "with torch.inference_mode():\n",
    "    output_ids = model.generate(\n",
    "        input_ids,\n",
    "        images=image_tensor.unsqueeze(0).half().cuda(),\n",
    "        do_sample=True if args.temperature > 0 else False,\n",
    "        temperature=args.temperature,\n",
    "        top_p=args.top_p,\n",
    "        num_beams=args.num_beams,\n",
    "        # no_repeat_ngram_size=3,\n",
    "        max_new_tokens=1024,\n",
    "        use_cache=True)\n",
    "\n",
    "input_token_len = input_ids.shape[1]\n",
    "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
    "if n_diff_input_output > 0:\n",
    "    print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
    "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
    "outputs = outputs.strip()\n",
    "if outputs.endswith(stop_str):\n",
    "    outputs = outputs[:-len(stop_str)]\n",
    "outputs = outputs.strip()\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41335d1f-0443-4451-b415-6c456c81350a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llava",
   "language": "python",
   "name": "llava"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
